from typing import Any, overload
from typing_extensions import TypeAlias

import numpy as np
import tensorflow as tf
from tensorflow._aliases import FloatArray, IntArray

# The alias below is not fully accurate: because TensorFlow casts the inputs, they have some additional
# requirements. For example, y needs to be castable to x's dtype. Moreover, x and y cannot both be booleans.
# Properly typing the bitwise functions would be overly complicated and would likely provide little benefit,
# since most people pass Tensors, so it was not done.
# Operand types accepted by the dense (tf.Tensor-returning) overloads below.
# tf.RaggedTensor is deliberately absent; ragged inputs are routed to separate overloads.
_BitwiseCompatible: TypeAlias = tf.Tensor | int | FloatArray | IntArray | np.number[Any]

# bitwise_and: dense/scalar/NumPy operands -> tf.Tensor; two RaggedTensor operands -> tf.RaggedTensor.
# NOTE(review): a RaggedTensor combined with a _BitwiseCompatible operand matches neither overload;
# confirm whether TensorFlow's ragged dispatch accepts mixed operands and widen `y` if so.
@overload
def bitwise_and(x: _BitwiseCompatible, y: _BitwiseCompatible, name: str | None = None) -> tf.Tensor: ...
@overload
def bitwise_and(x: tf.RaggedTensor, y: tf.RaggedTensor, name: str | None = None) -> tf.RaggedTensor: ...
# bitwise_or: dense/scalar/NumPy operands -> tf.Tensor; two RaggedTensor operands -> tf.RaggedTensor.
# NOTE(review): mixed RaggedTensor/dense operands match neither overload; confirm against TF ragged dispatch.
@overload
def bitwise_or(x: _BitwiseCompatible, y: _BitwiseCompatible, name: str | None = None) -> tf.Tensor: ...
@overload
def bitwise_or(x: tf.RaggedTensor, y: tf.RaggedTensor, name: str | None = None) -> tf.RaggedTensor: ...
# bitwise_xor: dense/scalar/NumPy operands -> tf.Tensor; two RaggedTensor operands -> tf.RaggedTensor.
# NOTE(review): mixed RaggedTensor/dense operands match neither overload; confirm against TF ragged dispatch.
@overload
def bitwise_xor(x: _BitwiseCompatible, y: _BitwiseCompatible, name: str | None = None) -> tf.Tensor: ...
@overload
def bitwise_xor(x: tf.RaggedTensor, y: tf.RaggedTensor, name: str | None = None) -> tf.RaggedTensor: ...
# invert (single operand): a dense/scalar/NumPy input yields tf.Tensor; a RaggedTensor input yields tf.RaggedTensor.
@overload
def invert(x: _BitwiseCompatible, name: str | None = None) -> tf.Tensor: ...
@overload
def invert(x: tf.RaggedTensor, name: str | None = None) -> tf.RaggedTensor: ...
# left_shift: dense/scalar/NumPy operands -> tf.Tensor; two RaggedTensor operands -> tf.RaggedTensor.
# NOTE(review): mixed RaggedTensor/dense operands (e.g. a ragged x with an int shift) match neither
# overload; confirm whether TF supports this and widen `y` if so.
@overload
def left_shift(x: _BitwiseCompatible, y: _BitwiseCompatible, name: str | None = None) -> tf.Tensor: ...
@overload
def left_shift(x: tf.RaggedTensor, y: tf.RaggedTensor, name: str | None = None) -> tf.RaggedTensor: ...
# right_shift: dense/scalar/NumPy operands -> tf.Tensor; two RaggedTensor operands -> tf.RaggedTensor.
# NOTE(review): mixed RaggedTensor/dense operands (e.g. a ragged x with an int shift) match neither
# overload; confirm whether TF supports this and widen `y` if so.
@overload
def right_shift(x: _BitwiseCompatible, y: _BitwiseCompatible, name: str | None = None) -> tf.Tensor: ...
@overload
def right_shift(x: tf.RaggedTensor, y: tf.RaggedTensor, name: str | None = None) -> tf.RaggedTensor: ...
